"""
PyTorch implementation of the DADA optimizers
"""
import logging

import torch
from torch.optim import Optimizer

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class WDA(Optimizer):
    """Weighted Dual Averaging (WDA) optimizer.

    On every step the normalised gradient, scaled by ``d0``, is accumulated
    into ``grad_avg`` and each parameter is placed at::

        init_point - grad_avg / sqrt(k)

    where ``k`` is the 1-based step counter stored in the parameter group
    under the key ``'step'``.

    Per-parameter state keys:
        grad_avg:   running sum of ``d0 * g / ||g||``.
        init_point: copy of the parameter taken when its state was created.
        eta:        displacement from ``init_point`` applied by the latest step.
    """

    __version__ = '1.0.0'

    def __init__(self, params, d0=1):
        """
        Implements Weighted Dual Averaging (WDA) as a PyTorch Optimizer.

        Args:
            params: Iterable of parameters to optimize or dicts defining parameter groups.
            d0: Distance between the initial point and the optimal point.
        """
        # Pins the reported class name; a no-op for WDA itself, but it also
        # renames any subclass instantiated through this constructor.
        self.__class__.__name__ = "WDA"

        defaults = dict(d0=d0)
        super().__init__(params, defaults)

        # Snapshot the starting point of every trainable parameter right away,
        # so that ``init_point`` reflects the values at construction time.
        for group in self.param_groups:
            for param in group['params']:
                if param.requires_grad:
                    self._ensure_state(param)

    def __setstate__(self, state):
        super().__setstate__(state)

    def _ensure_state(self, param):
        """Return the state dict of ``param``, creating missing entries.

        Covers parameters that had no state created in ``__init__`` (added via
        ``add_param_group`` or with ``requires_grad`` enabled afterwards).
        """
        state = self.state[param]
        if 'grad_avg' not in state:
            state['grad_avg'] = torch.zeros_like(param)
        if 'init_point' not in state:
            state['init_point'] = param.clone().detach()
        return state

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self._update_group_state(group)
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                # In-place write keeps the parameter's storage stable.
                p.copy_(state['init_point'] - state['eta'])

        return loss

    def _update_group_state(self, group):
        """Advance the group's step counter and refresh ``grad_avg`` / ``eta``.

        Parameters without a gradient are skipped and keep their state as is.
        """
        k = group.get('step', 0) + 1
        group['step'] = k
        # Plain Python scalar: keeps the dtype and device of ``grad_avg``.
        eta_scale = k ** -0.5

        for p in group['params']:
            if p.grad is None:
                continue
            state = self._ensure_state(p)
            g = p.grad.detach()
            # Clamp the norm so an all-zero gradient contributes zero instead
            # of producing NaN through 0 / 0.
            safe_norm = g.norm().clamp_min(torch.finfo(g.dtype).tiny)
            state['grad_avg'] += group['d0'] * (g / safe_norm)
            state['eta'] = eta_scale * state['grad_avg']

    def has_d_estimator(self):
        """Return ``False``: WDA uses the fixed ``d0`` and does not estimate the distance."""
        return False
